Skip to content

[Model] Add Ming-omni-tts dense 0.5B pipeline#2906

Open
akshatvishu wants to merge 67 commits into
vllm-project:mainfrom
akshatvishu:feat/ming-omni-tts-dense
Open

[Model] Add Ming-omni-tts dense 0.5B pipeline#2906
akshatvishu wants to merge 67 commits into
vllm-project:mainfrom
akshatvishu:feat/ming-omni-tts-dense

Conversation

@akshatvishu
Copy link
Copy Markdown
Contributor

@akshatvishu akshatvishu commented Apr 18, 2026

Purpose

Add Ming-omni-tts dense 0.5B support to vLLM-Omni via a two-stage AR+Flow → Audio VAE pipeline.

Original repo : https://github.com/inclusionAI/Ming-omni-tts

Resolves:
#1461

Changes:
Model files (vllm_omni/model_executor/models/ming_tts/)

  • ming_tts.py — top-level two-stage dispatcher and weight-loading entry point
  • ming_tts_llm.py — Stage-0 Qwen2 AR backbone with inline Aggregator, FlowLoss, stop head, and latent patch emission
  • ming_tts_audio_vae.py — Stage-1 Audio VAE decoder producing 44.1 kHz mono waveform output
  • config_ming_tts.py — Ming dense constants, runtime keys, latent sizes, token IDs, stop-head defaults, and sample-rate validation
  • configuration_ming_dense.py — Hugging Face config adapter for inclusionAI/Ming-omni-tts-0.5B
  • prompt_builder.py — prompt construction for speech, music, instructions, TTA, prompt waveform, and speaker embeddings
  • ingress.py — first-stage prompt ingestion for the disaggregated pipeline
  • speaker_extractor.py — CampPlus 192-d speaker embedding extraction for reference audio
  • fm/ — Flow Matching modules used by Stage-0 latent generation
  • audio_tokenizer/ — Ming Audio VAE tokenizer and decoder support modules

Registry

  • Register MingTTSForConditionalGeneration, MingLLMModel, and MingAudioVAEModel in vllm_omni/model_executor/models/registry.py

Stage config & input processors

  • vllm_omni/model_executor/stage_configs/ming_tts.yaml — sequential two-stage AR+Flow → Audio VAE pipeline
  • vllm_omni/model_executor/stage_configs/ming_tts_async_chunk.yaml — async chunk pipeline with SharedMemoryConnector, latent_chunk_size: 25, and max_num_seqs: 1
  • vllm_omni/model_executor/stage_input_processors/ming_tts.py — Stage-0 → Stage-1 latent patch transfer for llm2audio_vae and llm2audio_vae_async_chunk, including final partial chunk flush

Offline examples

  • examples/offline_inference/ming_tts/end2end.py — end-to-end Omni example covering 11 cookbook cases: style, ip, bgm, tta, emotion, basic, dialect, zero_shot, podcast, speech_bgm, speech_sound
  • examples/offline_inference/ming_tts/README.md — offline launch notes for sequential and async chunk runs

Online serving

  • vllm_omni/entrypoints/openai/serving_speech.py — Ming prompt builder for OpenAI-compatible /v1/audio/speech, with structured instructions, voice → IP, language → dialect, reference audio, 192-d speaker embeddings, podcast multi-speaker conditioning, and streaming PCM/WAV output
  • examples/online_serving/ming_tts/run_server.sh — async chunk server launch script
  • examples/online_serving/ming_tts/openai_speech_client.py — API client covering Ming controls and streaming output
  • examples/online_serving/ming_tts/run_curl.sh — curl examples for /v1/audio/speech
  • examples/online_serving/ming_tts/README.md and docs/user_guide/examples/online_serving/ming_tts.md — online serving documentation

Architecture:

Stage 0: Qwen2ForCausalLM + Aggregator + FlowLoss → latent audio patches
Stage 1: Ming Audio VAE → 44.1 kHz mono waveform

Known limitations / follow-ups:

  • Online /v1/audio/speech does not yet expose prompt_mode=music/tta or FlowLoss controls (cfg, sigma, temperature); online BGM and TTA require a future prompt-mode API extension.
  • Stage configs use max_num_seqs: 1; multi-request batching is not yet validated.
  • latent_chunk_size: 5 improves online TTFP significantly but diverges on podcast in the offline async matrix; repo YAML stays on the validated latent_chunk_size: 25 default until that is resolved.

Test Plan

Validation was performed on an NVIDIA L4 GPU (Colab).

Offline sequential — full 11-case cookbook matrix :

python examples/offline_inference/ming_tts/end2end.py --case <case>

Offline async_chunk — full 11-case cookbook matrix:

python examples/offline_inference/ming_tts/end2end.py \
    --case <case> \
    --streaming \
    --stage-configs-path vllm_omni/model_executor/stage_configs/ming_tts_async_chunk.yaml \
    --enforce-eager

Online serving/v1/audio/speech async_chunk checks:

# Start server
vllm-omni serve inclusionAI/Ming-omni-tts-0.5B \
    --stage-configs-path vllm_omni/model_executor/stage_configs/ming_tts_async_chunk.yaml \
    --omni --enforce-eager

# Run client checks (style, ip, basic, emotion, dialect, zero_shot, podcast,
# speech_bgm, speech_sound, streaming PCM, ref_audio, speaker_embedding, podcast multi-ref)

Test Result

Offline correctness — sequential vs. async_chunk (latent_chunk_size: 25):

All 11 cases produced identical frame counts and Stage-1 total patch counts between sequential and default async_chunk, confirming correct Stage-0 → Stage-1 handoff and final partial chunk flush.

Case Frames / Patches / Audio (s) Seq = Async25
style 409248 / 29 / 9.28
ip 183456 / 13 / 4.16
bgm 1326528 / 94 / 30.08
tta 465696 / 33 / 10.56
emotion 324576 / 23 / 7.36
basic 211680 / 15 / 4.80
dialect 239904 / 17 / 5.44
zero_shot 409248 / 29 / 9.28
podcast 437472 / 31 / 9.92
speech_bgm 296352 / 21 / 6.72
speech_sound 352800 / 25 / 8.00

Upstream FlashAttention comparison (cold, single-request, L4):

Upstream: torch 2.6.0+cu124, FlashAttention 2.7.4.post1. vLLM-Omni VAE stage runs through SDPA, not upstream FlashAttention. Integration comparison, not kernel parity benchmark.

Case Upstream RTF vLLM Seq RTF vLLM Async25 RTF
style 0.704 1.026 0.709
ip 0.695 0.978 1.045
bgm 0.692 0.611 0.571
emotion 0.688 0.823 0.830
basic 0.689 0.918 0.917
dialect 0.684 0.869 0.915
zero_shot 0.692 0.754 0.684
podcast 0.697 0.735 0.676
speech_bgm 0.687 0.823 0.820
speech_sound 0.681 0.772 0.808
(avg, 10 cases) 0.691 0.831 0.798

vLLM-Omni matches or beats upstream RTF on bgm; async25 is near-parity on style, zero_shot, and podcast. Cold single-request numbers include engine startup and first-request lazy setup costs.

Warm-cache RTF vs upstream (L4, post-warmup, 1 warmup + 1 measured request):

Warm-cache removes first-request lazy setup. Fairer per-request comparison against upstream.

Case Upstream RTF vLLM Seq RTF (warm) vLLM Async25 RTF (warm) Seq delta vs upstream
style 0.704 0.563 0.549 -20.0%
ip 0.695 0.565 0.569 -18.7%
bgm 0.692 0.571 0.520 -17.5%
emotion 0.688 0.548 0.571 -20.3%
basic 0.689 0.572 0.602 -17.0%
dialect 0.684 0.569 0.557 -16.8%
zero_shot 0.692 0.565 0.491 -18.4%
podcast 0.697 0.559 0.521 -19.8%
speech_bgm 0.687 0.570 0.574 -17.0%
speech_sound 0.681 0.555 0.568 -18.5%
(avg, 10 cases) 0.691 0.564 0.552 -18.4%

Warm vLLM-Omni sequential beats upstream FlashAttention RTF across all 10 measured cases. Async25 further reduces RTF for longer/reference-conditioned cases and the zero-ref style / bgm runs.

Warm-cache offline benchmark (L4, 1 warmup + 1 measured request):

Case Seq wall / RTF Async25 wall / RTF / TTFP Delta
style 4.864s / 0.563 4.743s / 0.549 / 4.675s -2.5%
ip 2.169s / 0.565 2.185s / 0.569 / 2.181s +0.7%
bgm 17.174s / 0.571 15.632s / 0.520 / 4.819s -9.0%
zero_shot 5.425s / 0.565 4.711s / 0.491 / 4.515s -13.2%
podcast 5.547s / 0.559 5.164s / 0.521 / 4.740s -6.9%
tta 5.830s / 0.552 5.571s / 0.528 / 4.752s -4.4%
basic 3.112s / 0.572 3.277s / 0.602 / 3.273s +5.3%

Async chunk benefits longer/reference-conditioned cases; overhead roughly cancels the overlap benefit for short speech cases.

Online serving benchmark (10 prompts, concurrency 1, eager, L4):

Config Mean TTFP (ms) Mean E2E (ms) Mean RTF
sequential_eager 3354.83 3357.01 0.561
async_chunk_eager (chunk=25) 3450.28 3452.35 0.577
async_chunk_bench (chunk=5) 911.20 2985.04 0.499

latent_chunk_size: 5 reduces mean TTFP by ~73% and E2E by ~11% vs. sequential, but remains experimental pending podcast offline finalization.

Online /v1/audio/speech validation (async_chunk, all speech-mode cases):

All cases returned valid WAV at 44.1 kHz. Streaming PCM returned progressive chunks. Reference audio, speaker embedding, and podcast multi-reference checks passed.

Case Output Size (bytes) Sample rate Frames
style WAV 790316 44100 395136
ip WAV 366956 44100 183456
basic WAV 536300 44100 268128
emotion WAV 649196 44100 324576
dialect WAV 395180 44100 197568
zero_shot WAV 931436 44100 465696
podcast WAV 846764 44100 423360
speech_bgm WAV 677420 44100 338688
speech_sound WAV 649196 44100 324576
streaming PCM 338688 N/A N/A

BEFORE SUBMITTING, PLEASE READ https://github.com/vllm-project/vllm-omni/blob/main/CONTRIBUTING.md (anything written below this line will be removed by GitHub Actions)

@chatgpt-codex-connector
Copy link
Copy Markdown

Codex usage limits have been reached for code reviews. Please check with the admins of this repo to increase the limits by adding credits.
Credits must be used to enable repository wide code reviews.

@akshatvishu
Copy link
Copy Markdown
Contributor Author

Codex usage limits have been reached for code reviews. Please check with the admins of this repo to increase the limits by adding credits. Credits must be used to enable repository wide code reviews.

I guess everyone is suffering under the new limits (╥_╥)

Copy link
Copy Markdown
Collaborator

@hsliuustc0106 hsliuustc0106 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This PR is marked as [WIP] and is substantial (~10,500 lines / 47 files).

Could you please run the L3 tests locally and paste the results here? This helps validate the integration on your end before we proceed with full review.

Copy link
Copy Markdown
Collaborator

@hsliuustc0106 hsliuustc0106 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please make changes accordingly after #2383 merged. For the model usage, I suggest to write a model recipe under vllm_omni/recipes using the template. It seems there are some duplicate/dead codes as well, can you try to compress it first?

@hsliuustc0106
Copy link
Copy Markdown
Collaborator

I also recommend you to use the add-tts-models skill

Copy link
Copy Markdown
Collaborator

@linyueqian linyueqian left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the thorough test matrix and the warm-cache RTF numbers, those are the right kind of evidence for a model-add PR. At 10.5k additions the PR is hard to review carefully. I think it can stay as one PR if we condense it by reusing modules that already live in the repo. Inline comments below on the specific files, ordered roughly by expected line savings.

Not blocking merge, flagging for the author and maintainers.

Comment thread vllm_omni/model_executor/models/ming_tts/ming_tts_llm.py
@@ -0,0 +1,207 @@
# SPDX-License-Identifier: Apache-2.0
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[MAJOR] cosyvoice3/code2wav_core/cfm.py (325 lines) already implements Conditional Flow Matching. This PR adds fm/cfm.py (207) plus fm/modules.py (147), for roughly 350 lines of duplicated logic.

Suggestion: promote the cosyvoice3 CFM plus a DiT base to vllm_omni/model_executor/modules/flow_matching/, have Ming import it, and keep only fm/dit.py (Ming-specific conditioning) and fm/flowloss.py here.

This is a cross-model refactor, fine to land as a prerequisite PR owned by a maintainer or cc @yuanheng-zhao rather than blocking Ming on it. Worth an issue link from the PR body at minimum.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, I’ll file/link a follow-up issue for promoting the shared CFM/DiT base into vllm_omni/model_executor/modules/flow_matching/, unless you prefer this to be a prerequisite PR before Ming lands.

Comment thread vllm_omni/model_executor/stage_input_processors/ming_tts.py
@@ -0,0 +1,188 @@
# SPDX-License-Identifier: Apache-2.0
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[MINOR] Pure math. qwen3_tts/tokenizer_25hz/ and voxtral_tts/voxtral_tts_audio_tokenizer.py also ship an iSTFT. Recommend opening a follow-up issue to migrate all three to a shared vllm_omni/model_executor/modules/audio/stft.py. Not a blocker on this PR, but please file the issue so this does not go cold.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure! I’ll file/link a follow-up issue to migrate the repeated iSTFT implementations into a shared vllm_omni/model_executor/modules/audio/stft.py.

Comment thread vllm_omni/model_executor/models/ming_tts/speaker_extractor.py
Comment thread vllm_omni/model_executor/models/ming_tts/ming_tts.py
Comment thread vllm_omni/model_executor/models/ming_tts/config_ming_tts.py
Comment thread vllm_omni/model_executor/stage_configs/ming_tts_async_chunk.yaml Outdated
Comment thread examples/offline_inference/text_to_speech/ming_tts/end2end.py
Comment thread examples/online_serving/text_to_speech/ming_tts/run_curl.sh
@yuanheng-zhao
Copy link
Copy Markdown
Collaborator

@akshatvishu It seems there're a lot added files that could re-use modules from the talker of Ming-flash-omni-2.0 in #2890 , especially modelings such as talker llm, talker vae, fm, spkemb extractor.

I'll update #2890 later today and try to merge it ASAP and then you might want to rebase

cc @linyueqian @hsliuustc0106

@akshatvishu
Copy link
Copy Markdown
Contributor Author

@yuanheng-zhao Sure, I will wait for #2890 to get merge and will then start working on the suggestion left by @linyueqian as it seems like I can borrow a lot from Ming-flash-omni-2.0; after that I will run and upload the results of L3 test as requested by @hsliuustc0106

@yuanheng-zhao
Copy link
Copy Markdown
Collaborator

Hey @akshatvishu , the Ming-flash-omni-2.0 talker (modelings of that model for TTS & Omni-Speech) has been merged to main, let's rebase onto main with cutting off from the talker stage changes. For example,

git rebase --onto main the-talker-branch your-current-branch

Note to fetch and have latest main and my branch on your local

Signed-off-by: akshatvishu <akshatnayak197@gmail.com>
Signed-off-by: akshatvishu <akshatnayak197@gmail.com>
Signed-off-by: akshatvishu <akshatnayak197@gmail.com>
@akshatvishu akshatvishu force-pushed the feat/ming-omni-tts-dense branch from d949ec7 to 9add4ef Compare April 23, 2026 14:01
Signed-off-by: akshatvishu <akshatnayak197@gmail.com>
Signed-off-by: akshatvishu <akshatnayak197@gmail.com>
Signed-off-by: akshatvishu <akshatnayak197@gmail.com>
Signed-off-by: akshatvishu <akshatnayak197@gmail.com>
…s signature

Signed-off-by: akshatvishu <akshatnayak197@gmail.com>
…tecture

Signed-off-by: akshatvishu <akshatnayak197@gmail.com>
…to-detection fails

Signed-off-by: akshatvishu <akshatnayak197@gmail.com>
Signed-off-by: akshatvishu <akshatnayak197@gmail.com>
Signed-off-by: akshatvishu <akshatnayak197@gmail.com>
Signed-off-by: akshatvishu <akshatnayak197@gmail.com>

# Conflicts:
#	vllm_omni/engine/async_omni_engine.py
#	vllm_omni/entrypoints/openai/serving_speech.py
Signed-off-by: akshatvishu <akshatnayak197@gmail.com>
Signed-off-by: akshatvishu <akshatnayak197@gmail.com>
Signed-off-by: akshatvishu <akshatnayak197@gmail.com>
   Offline: promote AsyncOmni to a module-scoped fixture so the streaming
   test shares the engine init with the other three tests instead of
   paying
   a fresh two-stage load each run (~30 min → ~15 min on L4). Also cleans
   up the inline try/finally that the fixture teardown now handles.

   Online: replace four-level Path(__file__).parent chain with
   get_deploy_config_path("ming_tts.yaml"), matching the convention used
   by cosyvoice3 and moss_tts_nano. Drops the now-unused pathlib import.

Signed-off-by: akshatvishu <akshatnayak197@gmail.com>
Signed-off-by: akshatvishu <akshatnayak197@gmail.com>
Signed-off-by: akshatvishu <akshatnayak197@gmail.com>
Signed-off-by: akshatvishu <akshatnayak197@gmail.com>
@akshatvishu
Copy link
Copy Markdown
Contributor Author

Thanks for the incredibly thorough review @linyueqian !

I've just pushed a batch of commits that resolves all the [HIGH], [MEDIUM] and [LOW] architectural feedback.

Regarding the +7,300 LOC and deduplication:
I completely agree with your consolidation plan. Instead of opening a tracking issue and letting the tech debt linger, I want to tackle the Audio VAE, CFM/DiT, docs and helper de-duplications right here in this PR.

Since extracting the common components into ming_utils/ means touching the existing ming_flash_omni architecture, I want to be careful not to break its CFMGraphExecutor.

I'm working on this refactor now. Give me a day or two to move the shared logic, test it locally against both models and trim down the docs. I'll ping you for a re-review once it's pushed!

Signed-off-by: akshatvishu <akshatnayak197@gmail.com>
Signed-off-by: akshatvishu <akshatnayak197@gmail.com>
Signed-off-by: akshatvishu <akshatnayak197@gmail.com>
Signed-off-by: akshatvishu <akshatnayak197@gmail.com>
Signed-off-by: akshatvishu <akshatnayak197@gmail.com>
Signed-off-by: akshatvishu <akshatnayak197@gmail.com>
@akshatvishu
Copy link
Copy Markdown
Contributor Author

akshatvishu commented May 31, 2026

Hi @linyueqian , I split out the shared Ming modules as suggested but while validating the change I hit existing Ming Flash Omni failures on ROCm.

Tested with the official ROCm Docker image on an MI300X x8 node, provided through the AMD Developer Cloud program. Thanks to the AMD developer program team for granting access to the node.

https://github.com/vllm-project/vllm-omni/blob/main/tests/e2e/offline_inference/test_ming_flash_omni_expansion.py

All expansion tests initially failed (compatibility issues in the existing Ming Flash Omni path against the current vLLM/vLLM-Omni + transformers stack). The main failures were:

ENV:

`python collect_env.py` ```bash ============================== System Info ============================== OS : Ubuntu 22.04.5 LTS (x86_64) GCC version : (Ubuntu 11.4.0-1ubuntu1~22.04.3) 11.4.0 Clang version : 22.0.0git (https://github.com/RadeonOpenCompute/llvm-project roc-7.2.2 26084 f58b06dce1f9c15707c5f808fd002e18c2accf7e) CMake version : version 3.31.10 Libc version : glibc-2.35

==============================
PyTorch Info

PyTorch version : 2.10.0+git8514f05
Is debug build : False
CUDA used to build PyTorch : N/A
ROCM used to build PyTorch : 7.2.53211

==============================
Python Environment

Python version : 3.12.13 (main, Mar 4 2026, 09:23:07) [GCC 11.4.0] (64-bit runtime)
Python platform : Linux-6.8.0-106-generic-x86_64-with-glibc2.35

==============================
CUDA / GPU Info

Is CUDA available : True
CUDA runtime version : Could not collect
CUDA_MODULE_LOADING set to :
GPU models and configuration : (gfx942:sramecc+:xnack-)
Nvidia driver version : Could not collect
cuDNN version : Could not collect
HIP runtime version : 7.2.53211
MIOpen runtime version : 3.5.1
Is XNNPACK available : True

==============================
CPU Info

Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 46 bits physical, 57 bits virtual
Byte Order: Little Endian
CPU(s): 160
On-line CPU(s) list: 0-159
Vendor ID: GenuineIntel
Model name: INTEL(R) XEON(R) PLATINUM 8568Y+
CPU family: 6
Model: 207
Thread(s) per core: 1
Core(s) per socket: 80
Socket(s): 2
Stepping: 2
BogoMIPS: 4600.00
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc arch_perfmon pebs bts rep_good nopl xtopology cpuid tsc_known_freq pni pclmulqdq dtes64 vmx ssse3 fma cx16 pdcm pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch cpuid_fault ssbd ibrs ibpb stibp ibrs_enhanced tpr_shadow flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves avx_vnni avx512_bf16 wbnoinvd arat vnmi avx512vbmi umip pku ospke waitpkg avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq la57 rdpid bus_lock_detect cldemote movdiri movdir64b fsrm md_clear serialize tsxldtrk avx512_fp16 arch_capabilities
Virtualization: VT-x
Hypervisor vendor: KVM
Virtualization type: full
L1d cache: 5 MiB (160 instances)
L1i cache: 5 MiB (160 instances)
L2 cache: 640 MiB (160 instances)
NUMA node(s): 2
NUMA node0 CPU(s): 0-79
NUMA node1 CPU(s): 80-159
Vulnerability Gather data sampling: Not affected
Vulnerability Indirect target selection: Vulnerable
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Unknown: No mitigations
Vulnerability Reg file data sampling: Not affected
Vulnerability Retbleed: Not affected
Vulnerability Spec rstack overflow: Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Enhanced / Automatic IBRS; IBPB conditional; PBRSB-eIBRS SW sequence; BHI SW loop, KVM SW loop
Vulnerability Srbds: Not affected
Vulnerability Tsa: Not affected
Vulnerability Tsx async abort: Mitigation; TSX disabled
Vulnerability Vmscape: Not affected

==============================
Versions of relevant libraries

[pip3] conch-triton-kernels==1.2.1
[pip3] mypy==1.11.1
[pip3] mypy_extensions==1.1.0
[pip3] numpy==2.1.3
[pip3] onnx==1.19.0
[pip3] onnx-ir==0.2.1
[pip3] onnxruntime-rocm==1.22.2.post1
[pip3] onnxscript==0.7.0
[pip3] onnxslim==0.1.94
[pip3] pyzmq==27.1.0
[pip3] torch==2.10.0+git8514f05
[pip3] torch_c_dlpack_ext==0.1.5
[pip3] torch-complex==0.4.4
[pip3] torch-einops-utils==0.1.1
[pip3] torchaudio==2.9.0+eaa9e4e
[pip3] torchcodec==0.13.0
[pip3] torchdiffeq==0.2.5
[pip3] torchmetrics==1.9.0
[pip3] torchsde==0.2.6
[pip3] torchvision==0.24.1+d801a34
[pip3] transformers==5.8.1
[pip3] triton==3.6.0
[pip3] triton_kernels==1.0.0
[pip3] x-transformers==2.19.7
[conda] Could not collect

==============================
vLLM Info

ROCM Version : 7.2.53211-35e8c7bf89
vLLM Version : 0.22.0
vLLM-Omni Version : 0.1.dev1818+gbc49be130.rocm (git sha: bc49be1, date: ocm)
vLLM Build Flags:
CUDA Archs: Not Set; ROCm: Disabled
GPU Topology:
============================ ROCm System Management Interface ============================
================================ Weight between two GPUs =================================
GPU0 GPU1 GPU2 GPU3 GPU4 GPU5 GPU6 GPU7
GPU0 0 15 15 15 15 15 15 15
GPU1 15 0 15 15 15 15 15 15
GPU2 15 15 0 15 15 15 15 15
GPU3 15 15 15 0 15 15 15 15
GPU4 15 15 15 15 0 15 15 15
GPU5 15 15 15 15 15 0 15 15
GPU6 15 15 15 15 15 15 0 15
GPU7 15 15 15 15 15 15 15 0

================================= Hops between two GPUs ==================================
GPU0 GPU1 GPU2 GPU3 GPU4 GPU5 GPU6 GPU7
GPU0 0 1 1 1 1 1 1 1
GPU1 1 0 1 1 1 1 1 1
GPU2 1 1 0 1 1 1 1 1
GPU3 1 1 1 0 1 1 1 1
GPU4 1 1 1 1 0 1 1 1
GPU5 1 1 1 1 1 0 1 1
GPU6 1 1 1 1 1 1 0 1
GPU7 1 1 1 1 1 1 1 0

=============================== Link Type between two GPUs ===============================
GPU0 GPU1 GPU2 GPU3 GPU4 GPU5 GPU6 GPU7
GPU0 0 XGMI XGMI XGMI XGMI XGMI XGMI XGMI
GPU1 XGMI 0 XGMI XGMI XGMI XGMI XGMI XGMI
GPU2 XGMI XGMI 0 XGMI XGMI XGMI XGMI XGMI
GPU3 XGMI XGMI XGMI 0 XGMI XGMI XGMI XGMI
GPU4 XGMI XGMI XGMI XGMI 0 XGMI XGMI XGMI
GPU5 XGMI XGMI XGMI XGMI XGMI 0 XGMI XGMI
GPU6 XGMI XGMI XGMI XGMI XGMI XGMI 0 XGMI
GPU7 XGMI XGMI XGMI XGMI XGMI XGMI XGMI 0

======================================= Numa Nodes =======================================
GPU[0] : (Topology) Numa Node: 0
GPU[0] : (Topology) Numa Affinity: 0
GPU[1] : (Topology) Numa Node: 0
GPU[1] : (Topology) Numa Affinity: 0
GPU[2] : (Topology) Numa Node: 0
GPU[2] : (Topology) Numa Affinity: 0
GPU[3] : (Topology) Numa Node: 0
GPU[3] : (Topology) Numa Affinity: 0
GPU[4] : (Topology) Numa Node: 1
GPU[4] : (Topology) Numa Affinity: 1
GPU[5] : (Topology) Numa Node: 1
GPU[5] : (Topology) Numa Affinity: 1
GPU[6] : (Topology) Numa Node: 1
GPU[6] : (Topology) Numa Affinity: 1
GPU[7] : (Topology) Numa Node: 1
GPU[7] : (Topology) Numa Affinity: 1
================================== End of ROCm SMI Log ===================================

==============================
Environment Variables

VLLM_WORKER_MULTIPROC_METHOD=spawn
VLLM_ROCM_USE_AITER=0
PYTORCH_ROCM_ARCH=gfx942
LD_LIBRARY_PATH=/usr/local/lib/python3.12/dist-packages/cv2/../../lib64:/opt/rocm/lib:/usr/local/lib:
VLLM_LOGGING_LEVEL=DEBUG
PYTORCH_NVML_BASED_CUDA_CHECK=1
TORCHINDUCTOR_COMPILE_THREADS=1
TORCHINDUCTOR_CACHE_DIR=/tmp/torchinductor_root

Failure msg:

=================================================================================== short test summary info ====================================================================================
ERROR tests/e2e/offline_inference/test_ming_flash_omni_expansion.py::test_text_to_text[omni_runner0] - RuntimeError: Orchestrator initialization failed: StageEngineCoreProc died during READY (exit code 1)
ERROR tests/e2e/offline_inference/test_ming_flash_omni_expansion.py::test_image_to_text[omni_runner0] - RuntimeError: Orchestrator initialization failed: StageEngineCoreProc died during READY (exit code 1)
ERROR tests/e2e/offline_inference/test_ming_flash_omni_expansion.py::test_audio_to_text[omni_runner0] - RuntimeError: Orchestrator initialization failed: StageEngineCoreProc died during READY (exit code 1)
ERROR tests/e2e/offline_inference/test_ming_flash_omni_expansion.py::test_video_to_text[omni_runner0] - RuntimeError: Orchestrator initialization failed: StageEngineCoreProc died during READY (exit code 1)
ERROR tests/e2e/offline_inference/test_ming_flash_omni_expansion.py::test_mixed_to_text[omni_runner0] - RuntimeError: Orchestrator initialization failed: StageEngineCoreProc died during READY (exit code 1)
ERROR tests/e2e/offline_inference/test_ming_flash_omni_expansion.py::test_text_to_audio[omni_runner0] - RuntimeError: Orchestrator initialization failed: StageEngineCoreProc died during READY (exit code 1)
ERROR tests/e2e/offline_inference/test_ming_flash_omni_expansion.py::test_image_to_audio[omni_runner0] - RuntimeError: Orchestrator initialization failed: StageEngineCoreProc died during READY (exit code 1)

When I dig deeper it was due to a combination of transformer and vllm versions for most part :

test_image_to_text / test_mixed_to_text
- TypeError: Qwen2VLImageProcessorKwargs.__init__() got an unexpected keyword argument 'videos'

test_video_to_text
- TypeError from routing video through image_processor instead of video_processor

text/image/video to text paths
- RuntimeError: vllm::rocm_unquantized_gemm() expected Optional[Tensor] for bias but got SamplingMetadata
- Unrecognized keys in `rope_parameters` for 'rope_type'='default': {'mrope_section'}

test_text_to_audio / test_image_to_audio
- KeyError: 0 in ming_flash_omni.thinker2talker -> _validate_stage_inputs()

I have fixes for these in the follow-up commits:

  9862a46f Split Ming TTS prompt helpers by responsibility
  b9ea5552 Refactor Ming shared AudioVAE and CFM utilities
  1eabd876 Remove redundant Ming DIT checks
  149b0c09 Fix Ming Flash Omni transformer compatibility
  55c1b124 Fix Ming Flash Omni talker input bridge

To keep the current PR focused on the requested LOC reduction, I plan to keep only:

  98364f10 Simplify Ming TTS documentation links

in the current PR and move the shared-module + Ming Flash Omni ROCm fixes into a separate PR.

Does that split sound right? Or It would be okay to review these here itself!

…dense

# Conflicts:
#	tests/model_executor/stage_input_processors/test_qwen3_omni_streaming_helpers.py
#	vllm_omni/model_executor/models/ming_flash_omni/ming_flash_omni_thinker.py
#	vllm_omni/model_executor/stage_input_processors/ming_flash_omni.py
@akshatvishu
Copy link
Copy Markdown
Contributor Author

Hi @yuanheng-zhao !

Incase the pytest failures also reproduce on CUDA with the latest main, you try cherry-picking these two commits and re-running the test:

git cherry-pick 149b0c09 55c1b124
149b0c09 Fix Ming Flash Omni transformer compatibility
55c1b124 Fix Ming Flash Omni talker input bridge

They seem to resolve the issue on the ROCm side.

Please let me know if you’d prefer a different approach. I’m happy to make the changes.

@yuanheng-zhao
Copy link
Copy Markdown
Collaborator

Hi @yuanheng-zhao !

Incase the pytest failures also reproduce on CUDA with the latest main, you try cherry-picking these two commits and re-running the test:

git cherry-pick 149b0c09 55c1b124
149b0c09 Fix Ming Flash Omni transformer compatibility
55c1b124 Fix Ming Flash Omni talker input bridge

They seem to resolve the issue on the ROCm side.

Please let me know if you’d prefer a different approach. I’m happy to make the changes.

Sure, I think I could get back within 12 hrs. Btw, may I have the versions of vllm, vllm-omni, and transformers for your env for reference? Thanks.

@akshatvishu
Copy link
Copy Markdown
Contributor Author

akshatvishu commented Jun 1, 2026

Btw, may I have the versions of vllm, vllm-omni, and transformers for your env for reference? Thanks.

transformers==5.8.1, vLLM Version : 0.22.0,

vLLM-Omni Version : 0.1.dev1818+gbc49be130.rocm (git sha: bc49be1, date: ocm)

Also, the output of python collect_environment.py is at #2906 (comment) in-case you wanna check for any other dependencies!

I also tested this yesterday with vllm==0.21 and transformer version==4.57.6 and all the pytest still failed on ROCm!

Also, the sampling_metadata parameter become outdated from vllm after this change:vllm-project/vllm@1c3ffdb

P.S. I also pinged you on the vLLM Slack in case it’s easier to follow up there.

@yuanheng-zhao
Copy link
Copy Markdown
Collaborator

Hey @akshatvishu , I did reproduce errors with both transformers 4.57 and 5.5X. I tried your branch but got another error about split_routed_experts, which I thought should be resolved by the recent rebase. Have you encountered this one? And could you try to update your branch with main and resolve the conflicts

…dense

Signed-off-by: akshatvishu <akshatnayak197@gmail.com>

# Conflicts:
#	vllm_omni/entrypoints/openai/serving_speech.py
#	vllm_omni/worker/gpu_ar_model_runner.py
Signed-off-by: akshatvishu <akshatnayak197@gmail.com>
Signed-off-by: akshatvishu <akshatnayak197@gmail.com>
Signed-off-by: akshatvishu <akshatnayak197@gmail.com>
Signed-off-by: akshatvishu <akshatnayak197@gmail.com>
Signed-off-by: akshatvishu <akshatnayak197@gmail.com>
Adapt the Ming Flash Omni talker compatibility fixes suggested in PR vllm-project#4080.

Suggested-by: Yuanheng Zhao <jonathan.zhaoyh@gmail.com>
Signed-off-by: akshatvishu <akshatnayak197@gmail.com>
Comment on lines +1232 to +1238
if deploy_config_path is not None:
_deploy_path = Path(deploy_config_path)
if _deploy_path.exists():
_deploy_cfg = load_deploy_config(_deploy_path)
if _deploy_cfg.pipeline and _deploy_cfg.pipeline in _PIPELINE_REGISTRY:
return cls._create_from_registry(_deploy_cfg.pipeline, cli_overrides, deploy_config_path)

Copy link
Copy Markdown
Collaborator

@yuanheng-zhao yuanheng-zhao Jun 3, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

PTAL @alex-jw-brooks at the changes related with stage configs, do we currently have ongoing logics to handle this?

Suggested-by: Yuanheng Zhao <jonathan.zhaoyh@gmail.com>
Signed-off-by: akshatvishu <akshatnayak197@gmail.com>
Copy link
Copy Markdown

@Nightwing-77 Nightwing-77 Jun 3, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is the Ming model utils defined as a separate model executor? Could there be a better place for it? We should keep it aligned with the repository structure.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will this be more ideal?
From: vllm_omni/model_executor/models/ming_utils/
To: vllm_omni/model_executor/models/common/ming/

Copy link
Copy Markdown

@Nightwing-77 Nightwing-77 Jun 3, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will this be more ideal? From: vllm_omni/model_executor/models/ming_utils/ To: vllm_omni/model_executor/models/common/ming/

yea, probably!

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the suggestion @Nightwing-77 ! Just pushed the commit for the same.

Signed-off-by: akshatvishu <akshatnayak197@gmail.com>
@akshatvishu akshatvishu force-pushed the feat/ming-omni-tts-dense branch from 3b8d741 to 34d13eb Compare June 3, 2026 18:47
description="Language code (e.g., 'Chinese', 'English', 'Auto')",
)
ref_audio: str | None = Field(
ref_audio: str | list[str] | None = Field(
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does ming TTS support multiple ref audio!?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, it support stuff like podcast/multi-speaker TTS scenarios! You can check https://github.com/inclusionAI/Ming-omni-tts/blob/94a4d409/cookbooks/cookbook.ipynb#L192-L196 for more info!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

6 participants